{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import yaml\n",
    "\n",
    "from hfnet.datasets.cmu import Cmu\n",
    "from hfnet.evaluation.localization import Localization\n",
    "from hfnet.evaluation.utils.db_management import read_query_list\n",
    "from hfnet.evaluation.loaders import export_loader\n",
    "from hfnet.settings import DATA_PATH, EXPER_PATH\n",
    "\n",
    "from utils import plot_matches\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the CMU slice; reused by the cells below to build the experiment paths,\n",
    "# the query-list and eval filenames, and the dataset prefix.\n",
    "slice_ = 'slice2'"
   ]
  },
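  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run one of the three configs\n",
    "\n",
    "Execute exactly one of the three configuration cells below (NV+SP, HF-Net, or NV+SIFT): each of them overwrites `config_global`, `config_local` and `model`."
   ]
  },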
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NV+SP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings stored under the 'global' key of the config given to Localization.\n",
    "config_global = dict(\n",
    "    db_name='globaldb_netvlad.pkl',\n",
    "    experiment='netvlad/cmu/' + slice_,\n",
    "    predictor=export_loader,\n",
    "    has_keypoints=False,\n",
    "    has_descriptors=False,\n",
    "    pca_dim=1024,\n",
    "    num_prior=10,\n",
    ")\n",
    "\n",
    "# Settings stored under the 'local' key of the config given to Localization.\n",
    "config_local = dict(\n",
    "    db_name='localdb_superpoint.pkl',\n",
    "    experiment='superpoint/cmu/' + slice_,\n",
    "    predictor=export_loader,\n",
    "    has_keypoints=True,\n",
    "    has_descriptors=True,\n",
    "    binarize=False,\n",
    "    # Disabled options, kept for reference:\n",
    "    # do_nms=True,\n",
    "    # nms_thresh=4,\n",
    "    num_features=2000,\n",
    "    ratio_thresh=0.9,\n",
    ")\n",
    "\n",
    "# Model name handed to Localization in the setup cell.\n",
    "model = 'sp_model'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HF-Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings stored under the 'global' key of the config given to Localization.\n",
    "config_global = dict(\n",
    "    db_name='globaldb_hfnet.pkl',\n",
    "    experiment='hfnet/cmu/' + slice_,\n",
    "    predictor=export_loader,\n",
    "    has_keypoints=False,\n",
    "    has_descriptors=False,\n",
    "    pca_dim=1024,\n",
    "    num_prior=10,\n",
    ")\n",
    "\n",
    "# Settings stored under the 'local' key of the config given to Localization.\n",
    "config_local = dict(\n",
    "    db_name='localdb_hfnet.pkl',\n",
    "    experiment='hfnet/cmu/' + slice_,\n",
    "    predictor=export_loader,\n",
    "    has_keypoints=True,\n",
    "    has_descriptors=True,\n",
    "    binarize=False,\n",
    "    # Disabled options, kept for reference:\n",
    "    # do_nms=True,\n",
    "    # nms_thresh=4,\n",
    "    num_features=2000,\n",
    "    ratio_thresh=0.9,\n",
    ")\n",
    "\n",
    "# Model name handed to Localization in the setup cell.\n",
    "model = 'sp_model'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NV+SIFT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings stored under the 'global' key of the config given to Localization.\n",
    "config_global = dict(\n",
    "    db_name='globaldb_netvlad_sift-retry.pkl',\n",
    "    experiment='netvlad/cmu_resize-1024/' + slice_,\n",
    "    predictor=export_loader,\n",
    "    has_keypoints=False,\n",
    "    has_descriptors=False,\n",
    "    pca_dim=1024,\n",
    "    num_prior=10,\n",
    ")\n",
    "\n",
    "# Settings stored under the 'local' key of the config given to Localization.\n",
    "config_local = dict(\n",
    "    db_name='localdb_sift.pkl',\n",
    "    colmap_db='colmapdb_sift_database.db',\n",
    "    colmap_db_queries='colmapdb_sift_queries.db',\n",
    "    broken_paths=True,\n",
    "    ratio_thresh=0.7,\n",
    ")\n",
    "\n",
    "# Model name handed to Localization in the setup cell.\n",
    "model = 'sift_model'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup localization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[04/07/2019 13:44:30 INFO] Importing COLMAP model sp_model\n",
      "[04/07/2019 13:44:31 INFO] Number of images: 468\n",
      "Number of points: 47560\n",
      "Median keypoints per image: 1319.0\n",
      "Ratio of matched keypoints: 0.376\n",
      "\n",
      "[04/07/2019 13:44:31 INFO] Importing global and local databases\n",
      "[04/07/2019 13:44:32 INFO] Indexing descriptors\n"
     ]
    }
   ],
   "source": [
    "# Settings stored under the 'pose' key, next to the global/local settings\n",
    "# selected by whichever config cell was run.\n",
    "config_pose = dict(reproj_error=5, min_inliers=12)\n",
    "config = {\n",
    "    'global': config_global,\n",
    "    'local': config_local,\n",
    "    'pose': config_pose,\n",
    "}\n",
    "loc = Localization('cmu/' + slice_, model, config)\n",
    "\n",
    "# The query list file is looked up under loc.base_path.\n",
    "query_list_path = Path(loc.base_path, slice_ + '.queries_with_intrinsics.txt')\n",
    "queries = read_query_list(query_list_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optionally: isolate successful or failed queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 'eval_name' looks like a placeholder for the name of the evaluation run - confirm.\n",
    "eval_file = f'cmu/eval_name_{slice_}.yaml'\n",
    "with open(Path(EXPER_PATH, 'eval', eval_file), 'r') as f:\n",
    "    # safe_load: yaml.load(f) without an explicit Loader is unsafe and is rejected by PyYAML >= 6.\n",
    "    # NOTE(review): if the eval file contains Python-specific YAML tags,\n",
    "    # use yaml.load(f, Loader=yaml.Loader) instead.\n",
    "    failures = yaml.safe_load(f)['metrics']['failure']\n",
    "\n",
    "# Indices of the failed queries; the set is built once instead of once per query.\n",
    "failed_ids = set(failures)\n",
    "#queries = [queries[f] for f in failures]  # failures\n",
    "queries = [q for i, q in enumerate(queries) if i not in failed_ids]  # success"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ready query dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so that the in-place shuffle of the query list is reproducible.\n",
    "np.random.RandomState(0).shuffle(queries)\n",
    "query_dataset = Cmu(**{'resize_max': 1024,\n",
    "                       'image_names': [q.name for q in queries], 'prefix': slice_})\n",
    "\n",
    "\n",
    "def get_image(name):\n",
    "    \"\"\"Load the image `name` of the current slice from the dataset folder.\n",
    "\n",
    "    Returns the array read by OpenCV with its last (channel) axis reversed,\n",
    "    i.e. BGR -> RGB for a colour image.\n",
    "\n",
    "    Raises:\n",
    "        OSError: if OpenCV cannot read the file (cv2.imread returns None).\n",
    "    \"\"\"\n",
    "    path = Path(DATA_PATH, query_dataset.dataset_folder, slice_, name)\n",
    "    image = cv2.imread(path.as_posix())\n",
    "    if image is None:\n",
    "        # cv2.imread reports a missing or unreadable file by returning None instead of raising.\n",
    "        raise OSError(f'Could not read image: {path}')\n",
    "    return image[..., ::-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Localize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_shown = 5  # number of queries to localize and visualize\n",
    "query_iter = query_dataset.get_test_set()\n",
    "for i, query_info, query_data in zip(range(num_shown), queries, query_iter):\n",
    "    results, debug = loc.localize(query_info, query_data, debug=True)\n",
    "    s = f'{i} {\"Success\" if results.success else \"Failure\"}, inliers {results.num_inliers:^4}, ' \\\n",
    "        + f'ratio {results.inlier_ratio:.3f}, landmarks {len(debug[\"matching\"][\"lm_frames\"]):>4}, ' \\\n",
    "        + f'spl {debug[\"index_success\"]:>2}, places {[len(p) for p in debug[\"places\"]]:}, ' \\\n",
    "        + f'pos {[f\"{n:.1f}\" for n in results.T[:3, 3]]}'\n",
    "    print(s)\n",
    "\n",
    "    lm_frames = debug['matching']['lm_frames']\n",
    "    lm_indices = debug['matching']['lm_indices']\n",
    "    inlier_matches = debug['matches'][debug['inliers']]\n",
    "\n",
    "    # Count the inlier matches per frame and keep the frame with the most of them.\n",
    "    frame_ids, counts = np.unique(\n",
    "        [lm_frames[m2] for _, m2 in inlier_matches], return_counts=True)\n",
    "    if counts.size == 0:\n",
    "        # Without any inlier match np.argmax would raise on the empty array.\n",
    "        print(f'{i} no inlier matches, skipping visualization')\n",
    "        continue\n",
    "    best_id = frame_ids[np.argmax(counts)]\n",
    "\n",
    "    query_image = get_image(query_info.name)\n",
    "    best_image = get_image(loc.images[best_id].name)\n",
    "    # Inlier matches restricted to the best frame, as (query index, lm_indices entry) pairs.\n",
    "    best_matches_inliers = [(m1, lm_indices[m2])\n",
    "                            for m1, m2 in inlier_matches\n",
    "                            if lm_frames[m2] == best_id]\n",
    "\n",
    "    plot_matches(\n",
    "        query_image, debug['query_item'].keypoints,\n",
    "        best_image, loc.local_db[best_id].keypoints,\n",
    "        np.array(best_matches_inliers), color=(0, 1., 0),\n",
    "        dpi=100, ylabel=str(i), thickness=1.)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Platform LSF remote_gpu_hfnet",
   "language": "",
   "name": "rik_lsf_remote_gpu_hfnet"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
